{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Running hyperparameter optimization on a Chemprop model using Ray Tune"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/hpopting.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install chemprop (with the optional hpopt extras) from GitHub if running in Google Colab\n",
    "import os\n",
    "\n",
    "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
    "    try:\n",
    "        import chemprop\n",
    "    except ImportError:\n",
    "        !git clone https://github.com/chemprop/chemprop.git\n",
    "        %cd chemprop\n",
    "        # %pip (unlike !pip) installs into the environment of the running kernel\n",
    "        %pip install \".[hpopt]\"\n",
    "        %cd examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "2024-10-22 09:03:28,414\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
      "2024-10-22 09:03:28,801\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
      "2024-10-22 09:03:29,333\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "from lightning import pytorch as pl\n",
    "import ray\n",
    "from ray import tune\n",
    "from ray.train import CheckpointConfig, RunConfig, ScalingConfig\n",
    "from ray.train.lightning import (RayDDPStrategy, RayLightningEnvironment,\n",
    "                                 RayTrainReportCallback, prepare_trainer)\n",
    "from ray.train.torch import TorchTrainer\n",
    "from ray.tune.search.hyperopt import HyperOptSearch\n",
    "from ray.tune.search.optuna import OptunaSearch\n",
    "from ray.tune.schedulers import FIFOScheduler\n",
    "\n",
    "from chemprop import data, featurizers, models, nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs (assumes the notebook is run from the chemprop/examples directory)\n",
    "chemprop_dir = Path.cwd().parent\n",
    "input_path = chemprop_dir.joinpath(\"tests\", \"data\", \"regression\", \"mol\", \"mol.csv\")  # path to your data .csv file\n",
    "smiles_column = \"smiles\"  # name of the column containing SMILES strings\n",
    "target_columns = [\"lipo\"]  # names of the columns containing targets\n",
    "\n",
    "# Number of DataLoader workers; 0 loads data in the main process\n",
    "num_workers = 0\n",
    "\n",
    "# Directory where hyperparameter-optimization results are written\n",
    "hpopt_save_dir = Path.cwd().joinpath(\"hpopt\")\n",
    "hpopt_save_dir.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>lipo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14</td>\n",
       "      <td>3.54</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...</td>\n",
       "      <td>-1.18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl</td>\n",
       "      <td>3.69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...</td>\n",
       "      <td>3.37</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...</td>\n",
       "      <td>3.10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...</td>\n",
       "      <td>2.20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...</td>\n",
       "      <td>2.04</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...</td>\n",
       "      <td>4.49</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>COc1ccc(Cc2c(N)n[nH]c2N)cc1</td>\n",
       "      <td>0.20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...</td>\n",
       "      <td>2.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               smiles  lipo\n",
       "0             Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14  3.54\n",
       "1   COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... -1.18\n",
       "2              COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl  3.69\n",
       "3   OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...  3.37\n",
       "4   Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...  3.10\n",
       "..                                                ...   ...\n",
       "95  CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...  2.20\n",
       "96  CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...  2.04\n",
       "97  CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...  4.49\n",
       "98                        COc1ccc(Cc2c(N)n[nH]c2N)cc1  0.20\n",
       "99  CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...  2.00\n",
       "\n",
       "[100 rows x 2 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the raw data table and display it\n",
    "df_input = pd.read_csv(filepath_or_buffer=input_path)\n",
    "df_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract SMILES strings (1-D) and target values (2-D: one column per target) as NumPy arrays\n",
    "smis = df_input.loc[:, smiles_column].to_numpy()\n",
    "ys = df_input.loc[:, target_columns].to_numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make data points, splits, and datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one MoleculeDatapoint per (SMILES, targets) pair\n",
    "all_data = list(map(data.MoleculeDatapoint.from_smi, smis, ys))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RDKit Mol objects are used for structure-based splits\n",
    "mols = [datapoint.mol for datapoint in all_data]\n",
    "# Random 80/10/10 split into train / validation / test indices\n",
    "train_indices, val_indices, test_indices = data.make_split_indices(mols, \"random\", (0.8, 0.1, 0.1))\n",
    "train_data, val_data, test_data = data.split_data_by_indices(all_data, train_indices, val_indices, test_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()\n",
    "\n",
    "# NOTE(review): [0] selects the first (presumably only) split replicate -- confirm against make_split_indices\n",
    "# The scaler returned from the training set is reused to normalize the validation targets\n",
    "train_dset = data.MoleculeDataset(train_data[0], featurizer)\n",
    "scaler = train_dset.normalize_targets()\n",
    "val_dset = data.MoleculeDataset(val_data[0], featurizer)\n",
    "val_dset.normalize_targets(scaler)\n",
    "\n",
    "# Test targets are left unnormalized; train_model maps predictions back via UnscaleTransform\n",
    "test_dset = data.MoleculeDataset(test_data[0], featurizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a helper function to train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(config, train_dset, val_dset, num_workers, scaler, max_epochs=20):\n",
    "    \"\"\"Train a single Chemprop MPNN for one Ray Tune trial.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    config : dict\n",
    "        Hyperparameters sampled for this trial. Must contain the keys \"depth\",\n",
    "        \"ffn_hidden_dim\", \"ffn_num_layers\" and \"message_hidden_dim\".\n",
    "    train_dset, val_dset : data.MoleculeDataset\n",
    "        Training and validation datasets, with targets already normalized by `scaler`.\n",
    "    num_workers : int\n",
    "        Number of DataLoader workers; 0 loads data in the main process.\n",
    "    scaler\n",
    "        Target scaler fitted on the training set; used to unscale model predictions.\n",
    "    max_epochs : int, default 20\n",
    "        Number of epochs to train each trial for.\n",
    "    \"\"\"\n",
    "    # Cast to int in case the search algorithm hands back floats\n",
    "    depth = int(config[\"depth\"])\n",
    "    ffn_hidden_dim = int(config[\"ffn_hidden_dim\"])\n",
    "    ffn_num_layers = int(config[\"ffn_num_layers\"])\n",
    "    message_hidden_dim = int(config[\"message_hidden_dim\"])\n",
    "\n",
    "    train_loader = data.build_dataloader(train_dset, num_workers=num_workers, shuffle=True)\n",
    "    val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)\n",
    "\n",
    "    mp = nn.BondMessagePassing(d_h=message_hidden_dim, depth=depth)\n",
    "    agg = nn.MeanAggregation()\n",
    "    # Map predictions back to the original target scale using the training-set scaler\n",
    "    output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)\n",
    "    # Size the FFN input from the message-passing module instead of repeating d_h by hand\n",
    "    ffn = nn.RegressionFFN(\n",
    "        output_transform=output_transform,\n",
    "        input_dim=mp.output_dim,\n",
    "        hidden_dim=ffn_hidden_dim,\n",
    "        n_layers=ffn_num_layers,\n",
    "    )\n",
    "    metric_list = [nn.metrics.RMSE(), nn.metrics.MAE()]\n",
    "    model = models.MPNN(mp, agg, ffn, batch_norm=True, metrics=metric_list)\n",
    "\n",
    "    trainer = pl.Trainer(\n",
    "        accelerator=\"auto\",\n",
    "        devices=1,\n",
    "        max_epochs=max_epochs,  # number of epochs to train for\n",
    "        # below are needed for Ray and Lightning integration\n",
    "        strategy=RayDDPStrategy(),\n",
    "        callbacks=[RayTrainReportCallback()],\n",
    "        plugins=[RayLightningEnvironment()],\n",
    "    )\n",
    "\n",
    "    # Ray's hook for preparing the Lightning Trainer before fitting (see ray.train.lightning docs)\n",
    "    trainer = prepare_trainer(trainer)\n",
    "    trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define parameter search space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keys must match the config keys read in train_model.\n",
    "# Each entry is tune.qrandint(lower, upper, q): an integer distribution with step q.\n",
    "search_space = dict(\n",
    "    depth=tune.qrandint(2, 6, 1),\n",
    "    ffn_hidden_dim=tune.qrandint(300, 2400, 100),\n",
    "    ffn_num_layers=tune.qrandint(1, 3, 1),\n",
    "    message_hidden_dim=tune.qrandint(300, 2400, 100),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"tuneStatus\">\n",
       "  <div style=\"display: flex;flex-direction: row\">\n",
       "    <div style=\"display: flex;flex-direction: column;\">\n",
       "      <h3>Tune Status</h3>\n",
       "      <table>\n",
       "<tbody>\n",
       "<tr><td>Current time:</td><td>2024-10-22 09:05:01</td></tr>\n",
       "<tr><td>Running for: </td><td>00:01:23.70        </td></tr>\n",
       "<tr><td>Memory:      </td><td>10.9/15.3 GiB      </td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    </div>\n",
       "    <div class=\"vDivider\"></div>\n",
       "    <div class=\"systemInfo\">\n",
       "      <h3>System Info</h3>\n",
       "      Using FIFO scheduling algorithm.<br>Logical resource usage: 2.0/12 CPUs, 0/0 GPUs\n",
       "    </div>\n",
       "    \n",
       "  </div>\n",
       "  <div class=\"hDivider\"></div>\n",
       "  <div class=\"trialStatus\">\n",
       "    <h3>Trial Status</h3>\n",
       "    <table>\n",
       "<thead>\n",
       "<tr><th>Trial name           </th><th>status    </th><th>loc                 </th><th style=\"text-align: right;\">  train_loop_config/de\n",
       "pth</th><th style=\"text-align: right;\">     train_loop_config/ff\n",
       "n_hidden_dim</th><th style=\"text-align: right;\">  train_loop_config/ff\n",
       "n_num_layers</th><th style=\"text-align: right;\">    train_loop_config/me\n",
       "ssage_hidden_dim</th><th style=\"text-align: right;\">  iter</th><th style=\"text-align: right;\">  total time (s)</th><th style=\"text-align: right;\">  train_loss</th><th style=\"text-align: right;\">  train_loss_step</th><th style=\"text-align: right;\">  val/rmse</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>TorchTrainer_f1a6e41a</td><td>TERMINATED</td><td>172.31.231.162:24873</td><td style=\"text-align: right;\">2</td><td style=\"text-align: right;\">2000</td><td style=\"text-align: right;\">2</td><td style=\"text-align: right;\">500</td><td style=\"text-align: right;\">    20</td><td style=\"text-align: right;\">         49.8815</td><td style=\"text-align: right;\">   0.0990423</td><td style=\"text-align: right;\">         0.168217</td><td style=\"text-align: right;\">  0.861368</td></tr>\n",
       "<tr><td>TorchTrainer_d775c15d</td><td>TERMINATED</td><td>172.31.231.162:24953</td><td style=\"text-align: right;\">2</td><td style=\"text-align: right;\">2200</td><td style=\"text-align: right;\">2</td><td style=\"text-align: right;\">400</td><td style=\"text-align: right;\">    20</td><td style=\"text-align: right;\">         56.6533</td><td style=\"text-align: right;\">   0.069695 </td><td style=\"text-align: right;\">         0.119898</td><td style=\"text-align: right;\">  0.90258 </td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "  </div>\n",
       "</div>\n",
       "<style>\n",
       ".tuneStatus {\n",
       "  color: var(--jp-ui-font-color1);\n",
       "}\n",
       ".tuneStatus .systemInfo {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus td {\n",
       "  white-space: nowrap;\n",
       "}\n",
       ".tuneStatus .trialStatus {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus h3 {\n",
       "  font-weight: bold;\n",
       "}\n",
       ".tuneStatus .hDivider {\n",
       "  border-bottom-width: var(--jp-border-width);\n",
       "  border-bottom-color: var(--jp-border-color0);\n",
       "  border-bottom-style: solid;\n",
       "}\n",
       ".tuneStatus .vDivider {\n",
       "  border-left-width: var(--jp-border-width);\n",
       "  border-left-color: var(--jp-border-color0);\n",
       "  border-left-style: solid;\n",
       "  margin: 0.5em 1em 0.5em 1em;\n",
       "}\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Setting up process group for: env:// [rank=0, world_size=1]\n",
      "\u001b[36m(TorchTrainer pid=24873)\u001b[0m Started distributed worker processes: \n",
      "\u001b[36m(TorchTrainer pid=24873)\u001b[0m - (ip=172.31.231.162, pid=24952) world_rank=0, local_rank=0, node_rank=0\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m GPU available: False, used: False\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m TPU available: False, using: 0 TPU cores\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\n",
      "Epoch 0:   0%|          | 0/2 [00:00<?, ?it/s]                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Loading `train_dataloader` to estimate number of stepping batches.\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m /home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m /home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m \n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m   | Name            | Type               | Params | Mode \n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m ---------------------------------------------------------------\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m 0 | message_passing | BondMessagePassing | 579 K  | train\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m 1 | agg             | MeanAggregation    | 0      | train\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m 2 | bn              | BatchNorm1d        | 1.0 K  | train\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m 3 | predictor       | RegressionFFN      | 5.0 M  | train\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m 4 | X_d_transform   | Identity           | 0      | train\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m 5 | metrics         | ModuleList         | 0      | train\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m ---------------------------------------------------------------\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m 5.6 M     Trainable params\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m 0         Non-trainable params\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m 5.6 M     Total params\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m 22.346    Total estimated model params size (MB)\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m 27        Modules in train mode\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m 0         Modules in eval mode\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m /home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/core/saving.py:363: Skipping 'metrics' parameter because it is not possible to safely dump to YAML.\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m /home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0:  50%|█████     | 1/2 [00:00<00:00,  1.12it/s, v_num=0, train_loss_step=0.987]\n",
      "Epoch 0: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s, v_num=0, train_loss_step=1.040]\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m \n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 14.60it/s]\u001b[A\n",
      "Epoch 0: 100%|██████████| 2/2 [00:01<00:00,  1.67it/s, v_num=0, train_loss_step=1.040, val_loss=0.848]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000000)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 2/2 [00:01<00:00,  1.26it/s, v_num=0, train_loss_step=1.040, val_loss=0.848, train_loss_epoch=0.997]\n",
      "Epoch 1:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=1.040, val_loss=0.848, train_loss_epoch=0.997]        \n",
      "Epoch 1:  50%|█████     | 1/2 [00:00<00:00,  2.22it/s, v_num=0, train_loss_step=0.984, val_loss=0.848, train_loss_epoch=0.997]\n",
      "Epoch 1: 100%|██████████| 2/2 [00:00<00:00,  3.32it/s, v_num=0, train_loss_step=0.406, val_loss=0.848, train_loss_epoch=0.997]\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 18.53it/s]\u001b[A\n",
      "Epoch 1: 100%|██████████| 2/2 [00:00<00:00,  2.97it/s, v_num=0, train_loss_step=0.406, val_loss=0.904, train_loss_epoch=0.997]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:05,874\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000001)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|██████████| 2/2 [00:01<00:00,  1.90it/s, v_num=0, train_loss_step=0.406, val_loss=0.904, train_loss_epoch=0.869]\n",
      "Epoch 2:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.406, val_loss=0.904, train_loss_epoch=0.869]        \n",
      "Epoch 2:  50%|█████     | 1/2 [00:00<00:00,  1.15it/s, v_num=0, train_loss_step=1.190, val_loss=0.904, train_loss_epoch=0.869]\n",
      "Epoch 2: 100%|██████████| 2/2 [00:01<00:00,  1.81it/s, v_num=0, train_loss_step=1.290, val_loss=0.904, train_loss_epoch=0.869]\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m \n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 14.01it/s]\u001b[A\n",
      "Epoch 2: 100%|██████████| 2/2 [00:01<00:00,  1.66it/s, v_num=0, train_loss_step=1.290, val_loss=0.842, train_loss_epoch=0.869]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:07,873\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000002)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2: 100%|██████████| 2/2 [00:01<00:00,  1.29it/s, v_num=0, train_loss_step=1.290, val_loss=0.842, train_loss_epoch=1.210]\n",
      "Epoch 3:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=1.290, val_loss=0.842, train_loss_epoch=1.210]        \n",
      "Epoch 3:  50%|█████     | 1/2 [00:00<00:00,  1.80it/s, v_num=0, train_loss_step=0.890, val_loss=0.842, train_loss_epoch=1.210]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(TorchTrainer pid=24953)\u001b[0m Started distributed worker processes: \n",
      "\u001b[36m(TorchTrainer pid=24953)\u001b[0m - (ip=172.31.231.162, pid=25062) world_rank=0, local_rank=0, node_rank=0\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m Setting up process group for: env:// [rank=0, world_size=1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3: 100%|██████████| 2/2 [00:00<00:00,  2.44it/s, v_num=0, train_loss_step=0.749, val_loss=0.842, train_loss_epoch=1.210]\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m \n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 10.81it/s]\u001b[A\n",
      "Epoch 3: 100%|██████████| 2/2 [00:00<00:00,  2.15it/s, v_num=0, train_loss_step=0.749, val_loss=0.912, train_loss_epoch=1.210]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000003)\n",
      "2024-10-22 09:04:09,291\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3: 100%|██████████| 2/2 [00:01<00:00,  1.62it/s, v_num=0, train_loss_step=0.749, val_loss=0.912, train_loss_epoch=0.861]\n",
      "Epoch 4:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.749, val_loss=0.912, train_loss_epoch=0.861]        \n",
      "Epoch 4:  50%|█████     | 1/2 [00:00<00:00,  1.41it/s, v_num=0, train_loss_step=0.845, val_loss=0.912, train_loss_epoch=0.861]\n",
      "Epoch 4: 100%|██████████| 2/2 [00:00<00:00,  2.04it/s, v_num=0, train_loss_step=0.578, val_loss=0.912, train_loss_epoch=0.861]\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m \n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 10.38it/s]\u001b[A\n",
      "Epoch 4: 100%|██████████| 2/2 [00:01<00:00,  1.78it/s, v_num=0, train_loss_step=0.578, val_loss=0.912, train_loss_epoch=0.861]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:11,011\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000004)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|██████████| 2/2 [00:01<00:00,  1.31it/s, v_num=0, train_loss_step=0.578, val_loss=0.912, train_loss_epoch=0.792]\n",
      "Epoch 5:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.578, val_loss=0.912, train_loss_epoch=0.792]        \n",
      "Epoch 5:  50%|█████     | 1/2 [00:00<00:00,  1.60it/s, v_num=0, train_loss_step=0.584, val_loss=0.912, train_loss_epoch=0.792]\n",
      "Epoch 5: 100%|██████████| 2/2 [00:00<00:00,  2.58it/s, v_num=0, train_loss_step=0.751, val_loss=0.912, train_loss_epoch=0.792]\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m \n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 12.17it/s]\u001b[A\n",
      "Epoch 5: 100%|██████████| 2/2 [00:00<00:00,  2.26it/s, v_num=0, train_loss_step=0.751, val_loss=0.887, train_loss_epoch=0.792]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:12,441\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000005)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5: 100%|██████████| 2/2 [00:01<00:00,  1.59it/s, v_num=0, train_loss_step=0.751, val_loss=0.887, train_loss_epoch=0.618]\n",
      "Epoch 6:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.751, val_loss=0.887, train_loss_epoch=0.618]        \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m GPU available: False, used: False\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m TPU available: False, using: 0 TPU cores\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6:  50%|█████     | 1/2 [00:00<00:00,  1.64it/s, v_num=0, train_loss_step=0.421, val_loss=0.887, train_loss_epoch=0.618]\n",
      "Epoch 6: 100%|██████████| 2/2 [00:00<00:00,  2.56it/s, v_num=0, train_loss_step=0.569, val_loss=0.887, train_loss_epoch=0.618]\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 13.13it/s]\u001b[A\n",
      "Epoch 6: 100%|██████████| 2/2 [00:00<00:00,  2.28it/s, v_num=0, train_loss_step=0.569, val_loss=0.876, train_loss_epoch=0.618]\n",
      "Sanity Checking: |          | 0/? [00:00<?, ?it/s]\n",
      "Sanity Checking DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 12.06it/s]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m Loading `train_dataloader` to estimate number of stepping batches.\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m /home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m /home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m \n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m   | Name            | Type               | Params | Mode \n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m ---------------------------------------------------------------\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m 0 | message_passing | BondMessagePassing | 383 K  | train\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m 1 | agg             | MeanAggregation    | 0      | train\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m 2 | bn              | BatchNorm1d        | 800    | train\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m 3 | predictor       | RegressionFFN      | 5.7 M  | train\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m 4 | X_d_transform   | Identity           | 0      | train\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m 5 | metrics         | ModuleList         | 0      | train\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m ---------------------------------------------------------------\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m 6.1 M     Trainable params\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m 0         Non-trainable params\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m 6.1 M     Total params\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m 24.444    Total estimated model params size (MB)\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m 27        Modules in train mode\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m 0         Modules in eval mode\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m /home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/core/saving.py:363: Skipping 'metrics' parameter because it is not possible to safely dump to YAML.\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m /home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0:   0%|          | 0/2 [00:00<?, ?it/s]                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000006)\n",
      "2024-10-22 09:04:13,968\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6: 100%|██████████| 2/2 [00:01<00:00,  1.53it/s, v_num=0, train_loss_step=0.569, val_loss=0.876, train_loss_epoch=0.450]\n",
      "Epoch 7:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.569, val_loss=0.876, train_loss_epoch=0.450]        \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:14,855\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:15,207\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7:  50%|█████     | 1/2 [00:00<00:00,  2.28it/s, v_num=0, train_loss_step=0.339, val_loss=0.876, train_loss_epoch=0.450]\u001b[32m [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)\u001b[0m\n",
      "Epoch 1: 100%|██████████| 2/2 [00:00<00:00,  3.75it/s, v_num=0, train_loss_step=0.335, val_loss=0.854, train_loss_epoch=1.010]\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m \u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 16.17it/s]\u001b[A\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "Epoch 1: 100%|██████████| 2/2 [00:00<00:00,  3.26it/s, v_num=0, train_loss_step=0.335, val_loss=0.893, train_loss_epoch=1.010]\u001b[32m [repeated 3x across cluster]\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:15,979\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s, v_num=0, train_loss_step=0.335, val_loss=0.893, train_loss_epoch=0.703]\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "Epoch 2:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.335, val_loss=0.893, train_loss_epoch=0.703]\u001b[32m [repeated 3x across cluster]\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:16,509\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:17,399\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000009)\u001b[32m [repeated 6x across cluster]\u001b[0m\n",
      "2024-10-22 09:04:17,944\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:18,760\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:19,250\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:20,250\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11:  50%|█████     | 1/2 [00:00<00:00,  1.25it/s, v_num=0, train_loss_step=0.175, val_loss=0.897, train_loss_epoch=0.258]\u001b[32m [repeated 8x across cluster]\u001b[0m\n",
      "Epoch 11: 100%|██████████| 2/2 [00:01<00:00,  1.79it/s, v_num=0, train_loss_step=0.312, val_loss=0.897, train_loss_epoch=0.258]\u001b[32m [repeated 7x across cluster]\u001b[0m\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 7x across cluster]\u001b[0m\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 7x across cluster]\u001b[0m\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 7x across cluster]\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:20,955\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m \u001b[32m [repeated 11x across cluster]\u001b[0m\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  7.84it/s]\u001b[A\u001b[32m [repeated 7x across cluster]\u001b[0m\n",
      "Epoch 11: 100%|██████████| 2/2 [00:01<00:00,  1.56it/s, v_num=0, train_loss_step=0.312, val_loss=0.869, train_loss_epoch=0.258]\u001b[32m [repeated 7x across cluster]\u001b[0m\n",
      "Epoch 11: 100%|██████████| 2/2 [00:01<00:00,  1.27it/s, v_num=0, train_loss_step=0.312, val_loss=0.869, train_loss_epoch=0.203]\u001b[32m [repeated 7x across cluster]\u001b[0m\n",
      "Epoch 12:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.312, val_loss=0.869, train_loss_epoch=0.203]\u001b[32m [repeated 7x across cluster]\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:21,687\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:22,323\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:22,766\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000013)\u001b[32m [repeated 8x across cluster]\u001b[0m\n",
      "2024-10-22 09:04:24,404\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:25,524\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14:  50%|█████     | 1/2 [00:01<00:01,  0.88it/s, v_num=0, train_loss_step=0.131, val_loss=0.841, train_loss_epoch=0.141] \u001b[32m [repeated 6x across cluster]\u001b[0m\n",
      "Epoch 7: 100%|██████████| 2/2 [00:01<00:00,  1.13it/s, v_num=0, train_loss_step=0.368, val_loss=0.836, train_loss_epoch=0.399]\u001b[32m [repeated 5x across cluster]\u001b[0m\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 5x across cluster]\u001b[0m\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 5x across cluster]\u001b[0m\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 5x across cluster]\u001b[0m\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m \u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  7.76it/s]\u001b[A\u001b[32m [repeated 5x across cluster]\u001b[0m\n",
      "Epoch 7: 100%|██████████| 2/2 [00:01<00:00,  1.01it/s, v_num=0, train_loss_step=0.368, val_loss=0.843, train_loss_epoch=0.399]\u001b[32m [repeated 5x across cluster]\u001b[0m\n",
      "Epoch 7: 100%|██████████| 2/2 [00:02<00:00,  0.79it/s, v_num=0, train_loss_step=0.368, val_loss=0.843, train_loss_epoch=0.306]\u001b[32m [repeated 5x across cluster]\u001b[0m\n",
      "Epoch 8:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.368, val_loss=0.843, train_loss_epoch=0.306]\u001b[32m [repeated 5x across cluster]\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:27,188\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:28,260\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000015)\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "2024-10-22 09:04:30,172\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9:  50%|█████     | 1/2 [00:01<00:01,  0.72it/s, v_num=0, train_loss_step=0.216, val_loss=0.889, train_loss_epoch=0.254]\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "Epoch 9: 100%|██████████| 2/2 [00:01<00:00,  1.04it/s, v_num=0, train_loss_step=0.322, val_loss=0.889, train_loss_epoch=0.254]\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:31,460\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m \u001b[32m [repeated 9x across cluster]\u001b[0m\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  4.73it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Epoch 9: 100%|██████████| 2/2 [00:02<00:00,  0.90it/s, v_num=0, train_loss_step=0.322, val_loss=0.910, train_loss_epoch=0.254]\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Epoch 9: 100%|██████████| 2/2 [00:02<00:00,  0.70it/s, v_num=0, train_loss_step=0.322, val_loss=0.910, train_loss_epoch=0.237]\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Epoch 16:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.105, val_loss=0.809, train_loss_epoch=0.128]\u001b[32m [repeated 3x across cluster]\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:32,873\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:33,534\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:34,844\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/d775c15d/checkpoint_000011)\u001b[32m [repeated 5x across cluster]\u001b[0m\n",
      "2024-10-22 09:04:35,472\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 18:  50%|█████     | 1/2 [00:01<00:01,  0.98it/s, v_num=0, train_loss_step=0.0962, val_loss=0.781, train_loss_epoch=0.116]\u001b[32m [repeated 5x across cluster]\u001b[0m\n",
      "Epoch 11: 100%|██████████| 2/2 [00:01<00:00,  1.91it/s, v_num=0, train_loss_step=0.263, val_loss=0.889, train_loss_epoch=0.219]\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m \u001b[32m [repeated 5x across cluster]\u001b[0m\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  9.49it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Epoch 11: 100%|██████████| 2/2 [00:01<00:00,  1.68it/s, v_num=0, train_loss_step=0.263, val_loss=0.861, train_loss_epoch=0.219]\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Epoch 11: 100%|██████████| 2/2 [00:01<00:00,  1.19it/s, v_num=0, train_loss_step=0.263, val_loss=0.861, train_loss_epoch=0.146]\u001b[32m [repeated 6x across cluster]\u001b[0m\n",
      "Epoch 12:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.263, val_loss=0.861, train_loss_epoch=0.146]\u001b[32m [repeated 5x across cluster]\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:37,245\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:38,006\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000019)\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "2024-10-22 09:04:40,708\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:04:41,380\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "\u001b[36m(RayTrainWorker pid=24952)\u001b[0m `Trainer.fit` stopped: `max_epochs=20` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 13:  50%|█████     | 1/2 [00:00<00:00,  1.17it/s, v_num=0, train_loss_step=0.118, val_loss=0.849, train_loss_epoch=0.122]\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "Epoch 13: 100%|██████████| 2/2 [00:01<00:00,  1.62it/s, v_num=0, train_loss_step=0.0846, val_loss=0.849, train_loss_epoch=0.122]\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m \u001b[32m [repeated 5x across cluster]\u001b[0m\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  7.32it/s]\u001b[A\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Epoch 13: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s, v_num=0, train_loss_step=0.0846, val_loss=0.842, train_loss_epoch=0.122]\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
      "Epoch 19: 100%|██████████| 2/2 [00:03<00:00,  0.52it/s, v_num=0, train_loss_step=0.168, val_loss=0.742, train_loss_epoch=0.099]\u001b[32m [repeated 5x across cluster]\u001b[0m\n",
      "Epoch 14:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.0846, val_loss=0.842, train_loss_epoch=0.112]\u001b[32m [repeated 3x across cluster]\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:44,176\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15:  50%|█████     | 1/2 [00:01<00:01,  0.64it/s, v_num=0, train_loss_step=0.0923, val_loss=0.839, train_loss_epoch=0.0974]\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
      "Epoch 15: 100%|██████████| 2/2 [00:02<00:00,  0.94it/s, v_num=0, train_loss_step=0.0867, val_loss=0.839, train_loss_epoch=0.0974]\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m \u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 10.51it/s]\u001b[A\n",
      "Epoch 14: 100%|██████████| 2/2 [00:01<00:00,  1.63it/s, v_num=0, train_loss_step=0.126, val_loss=0.839, train_loss_epoch=0.112]\n",
      "Epoch 14: 100%|██████████| 2/2 [00:02<00:00,  0.87it/s, v_num=0, train_loss_step=0.126, val_loss=0.839, train_loss_epoch=0.0974]\n",
      "Epoch 15:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.126, val_loss=0.839, train_loss_epoch=0.0974]        \n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  3.15it/s]\u001b[A\n",
      "Epoch 15: 100%|██████████| 2/2 [00:02<00:00,  0.78it/s, v_num=0, train_loss_step=0.0867, val_loss=0.837, train_loss_epoch=0.0974]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/d775c15d/checkpoint_000015)\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "2024-10-22 09:04:48,312\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15: 100%|██████████| 2/2 [00:03<00:00,  0.54it/s, v_num=0, train_loss_step=0.0867, val_loss=0.837, train_loss_epoch=0.0912]\n",
      "Epoch 16:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.0867, val_loss=0.837, train_loss_epoch=0.0912]        \n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Epoch 16:  50%|█████     | 1/2 [00:02<00:02,  0.35it/s, v_num=0, train_loss_step=0.0792, val_loss=0.837, train_loss_epoch=0.0912]\n",
      "Epoch 16: 100%|██████████| 2/2 [00:03<00:00,  0.61it/s, v_num=0, train_loss_step=0.0703, val_loss=0.837, train_loss_epoch=0.0912]\n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  4.23it/s]\u001b[A\n",
      "Epoch 16: 100%|██████████| 2/2 [00:03<00:00,  0.56it/s, v_num=0, train_loss_step=0.0703, val_loss=0.837, train_loss_epoch=0.0912]\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m \u001b[32m [repeated 2x across cluster]\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:53,245\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 16: 100%|██████████| 2/2 [00:04<00:00,  0.41it/s, v_num=0, train_loss_step=0.0703, val_loss=0.837, train_loss_epoch=0.0774]\n",
      "Epoch 17:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.0703, val_loss=0.837, train_loss_epoch=0.0774]        \n",
      "Epoch 17:  50%|█████     | 1/2 [00:01<00:01,  0.90it/s, v_num=0, train_loss_step=0.0711, val_loss=0.837, train_loss_epoch=0.0774]\n",
      "Epoch 17: 100%|██████████| 2/2 [00:01<00:00,  1.36it/s, v_num=0, train_loss_step=0.156, val_loss=0.837, train_loss_epoch=0.0774] \n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 11.47it/s]\u001b[A\n",
      "Epoch 17: 100%|██████████| 2/2 [00:01<00:00,  1.23it/s, v_num=0, train_loss_step=0.156, val_loss=0.836, train_loss_epoch=0.0774]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/d775c15d/checkpoint_000017)\u001b[32m [repeated 2x across cluster]\u001b[0m2024-10-22 09:04:56,772\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 17: 100%|██████████| 2/2 [00:01<00:00,  1.01it/s, v_num=0, train_loss_step=0.156, val_loss=0.836, train_loss_epoch=0.0882]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 18:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.156, val_loss=0.836, train_loss_epoch=0.0882]        \n",
      "Epoch 18:  50%|█████     | 1/2 [00:00<00:00,  1.43it/s, v_num=0, train_loss_step=0.0684, val_loss=0.836, train_loss_epoch=0.0882]\n",
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m \u001b[32m [repeated 3x across cluster]\u001b[0m\n",
      "Epoch 18: 100%|██████████| 2/2 [00:00<00:00,  2.20it/s, v_num=0, train_loss_step=0.064, val_loss=0.836, train_loss_epoch=0.0882] \n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 12.20it/s]\u001b[A\n",
      "Epoch 18: 100%|██████████| 2/2 [00:01<00:00,  1.95it/s, v_num=0, train_loss_step=0.064, val_loss=0.830, train_loss_epoch=0.0882]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:04:58,523\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 18: 100%|██████████| 2/2 [00:01<00:00,  1.32it/s, v_num=0, train_loss_step=0.064, val_loss=0.830, train_loss_epoch=0.0675]\n",
      "Epoch 19:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0, train_loss_step=0.064, val_loss=0.830, train_loss_epoch=0.0675]        \n",
      "Epoch 19:  50%|█████     | 1/2 [00:00<00:00,  1.64it/s, v_num=0, train_loss_step=0.0571, val_loss=0.830, train_loss_epoch=0.0675]\n",
      "Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  2.53it/s, v_num=0, train_loss_step=0.120, val_loss=0.830, train_loss_epoch=0.0675] \n",
      "Validation: |          | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 13.51it/s]\u001b[A\n",
      "Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  2.23it/s, v_num=0, train_loss_step=0.120, val_loss=0.815, train_loss_epoch=0.0675]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:05:00,109\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|██████████| 2/2 [00:01<00:00,  1.55it/s, v_num=0, train_loss_step=0.120, val_loss=0.815, train_loss_epoch=0.0697]\n",
      "Epoch 19: 100%|██████████| 2/2 [00:01<00:00,  1.13it/s, v_num=0, train_loss_step=0.120, val_loss=0.815, train_loss_epoch=0.0697]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(RayTrainWorker pid=25062)\u001b[0m `Trainer.fit` stopped: `max_epochs=20` reached.\n",
      "2024-10-22 09:05:01,809\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
      "2024-10-22 09:05:01,823\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37' in 0.0305s.\n",
      "2024-10-22 09:05:01,873\tINFO tune.py:1048 -- Total run time: 83.87 seconds (83.66 seconds for the tuning loop).\n"
     ]
    }
   ],
   "source": [
    "# Start (or reuse) the local Ray runtime; ignore_reinit_error lets this cell be re-run in the same kernel\n",
    "ray.init(ignore_reinit_error=True)\n",
    "\n",
    "# FIFOScheduler runs trials in submission order\n",
    "scheduler = FIFOScheduler()\n",
    "\n",
    "# Scaling config controls the resources used by Ray\n",
    "scaling_config = ScalingConfig(\n",
    "    num_workers=1, # Ray training workers per trial (separate from the `num_workers` passed to train_model below)\n",
    "    use_gpu=False, # change to True if you want to use GPU\n",
    ")\n",
    "\n",
    "# Checkpoint config controls the checkpointing behavior of Ray\n",
    "checkpoint_config = CheckpointConfig(\n",
    "    num_to_keep=1, # number of checkpoints to keep\n",
    "    checkpoint_score_attribute=\"val_loss\", # Save the checkpoint based on this metric\n",
    "    checkpoint_score_order=\"min\", # Save the checkpoint with the lowest metric value\n",
    ")\n",
    "\n",
    "run_config = RunConfig(\n",
    "    checkpoint_config=checkpoint_config,\n",
    "    storage_path=hpopt_save_dir / \"ray_results\", # directory to save the results\n",
    ")\n",
    "\n",
    "# Each trial calls train_model with its sampled `config`; the datasets and scaler\n",
    "# are captured from the notebook scope by the lambda\n",
    "ray_trainer = TorchTrainer(\n",
    "    lambda config: train_model(\n",
    "        config, train_dset, val_dset, num_workers, scaler\n",
    "    ),\n",
    "    scaling_config=scaling_config,\n",
    "    run_config=run_config,\n",
    ")\n",
    "\n",
    "search_alg = HyperOptSearch(\n",
    "    n_initial_points=1, # number of random evaluations before tree parzen estimators\n",
    "    random_state_seed=42,\n",
    ")\n",
    "\n",
    "# OptunaSearch is another search algorithm that can be used\n",
    "# search_alg = OptunaSearch()\n",
    "\n",
    "tune_config = tune.TuneConfig(\n",
    "    metric=\"val_loss\",\n",
    "    mode=\"min\",\n",
    "    num_samples=2, # number of trials to run\n",
    "    scheduler=scheduler,\n",
    "    search_alg=search_alg,\n",
    "    trial_dirname_creator=lambda trial: str(trial.trial_id), # shorten filepaths\n",
    ")\n",
    "\n",
    "# The search space is nested under \"train_loop_config\" because TorchTrainer forwards\n",
    "# that entry to the training function as its `config` argument\n",
    "tuner = tune.Tuner(\n",
    "    ray_trainer,\n",
    "    param_space={\n",
    "        \"train_loop_config\": search_space,\n",
    "    },\n",
    "    tune_config=tune_config,\n",
    ")\n",
    "\n",
    "# Start the hyperparameter search\n",
    "results = tuner.fit()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameter optimization results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResultGrid<[\n",
       "  Result(\n",
       "    metrics={'train_loss': 0.09904231131076813, 'train_loss_step': 0.16821686923503876, 'val/rmse': 0.8613682389259338, 'val/mae': 0.7006751298904419, 'val_loss': 0.7419552206993103, 'train_loss_epoch': 0.09904231131076813, 'epoch': 19, 'step': 40},\n",
       "    path='/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a',\n",
       "    filesystem='local',\n",
       "    checkpoint=Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000019)\n",
       "  ),\n",
       "  Result(\n",
       "    metrics={'train_loss': 0.06969495117664337, 'train_loss_step': 0.11989812552928925, 'val/rmse': 0.902579665184021, 'val/mae': 0.7176367044448853, 'val_loss': 0.8146500587463379, 'train_loss_epoch': 0.06969495117664337, 'epoch': 19, 'step': 40},\n",
       "    path='/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/d775c15d',\n",
       "    filesystem='local',\n",
       "    checkpoint=Checkpoint(filesystem=local, path=/home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/d775c15d/checkpoint_000019)\n",
       "  )\n",
       "]>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ResultGrid with one Result per trial: last reported metrics, trial directory and retained checkpoint\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>train_loss</th>\n",
       "      <th>train_loss_step</th>\n",
       "      <th>val/rmse</th>\n",
       "      <th>val/mae</th>\n",
       "      <th>val_loss</th>\n",
       "      <th>train_loss_epoch</th>\n",
       "      <th>epoch</th>\n",
       "      <th>step</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>checkpoint_dir_name</th>\n",
       "      <th>...</th>\n",
       "      <th>pid</th>\n",
       "      <th>hostname</th>\n",
       "      <th>node_ip</th>\n",
       "      <th>time_since_restore</th>\n",
       "      <th>iterations_since_restore</th>\n",
       "      <th>config/train_loop_config/depth</th>\n",
       "      <th>config/train_loop_config/ffn_hidden_dim</th>\n",
       "      <th>config/train_loop_config/ffn_num_layers</th>\n",
       "      <th>config/train_loop_config/message_hidden_dim</th>\n",
       "      <th>logdir</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.099042</td>\n",
       "      <td>0.168217</td>\n",
       "      <td>0.861368</td>\n",
       "      <td>0.700675</td>\n",
       "      <td>0.741955</td>\n",
       "      <td>0.099042</td>\n",
       "      <td>19</td>\n",
       "      <td>40</td>\n",
       "      <td>1729602279</td>\n",
       "      <td>checkpoint_000019</td>\n",
       "      <td>...</td>\n",
       "      <td>24873</td>\n",
       "      <td>Knathan-Laptop</td>\n",
       "      <td>172.31.231.162</td>\n",
       "      <td>49.881516</td>\n",
       "      <td>20</td>\n",
       "      <td>2</td>\n",
       "      <td>2000</td>\n",
       "      <td>2</td>\n",
       "      <td>500</td>\n",
       "      <td>f1a6e41a</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.069695</td>\n",
       "      <td>0.119898</td>\n",
       "      <td>0.902580</td>\n",
       "      <td>0.717637</td>\n",
       "      <td>0.814650</td>\n",
       "      <td>0.069695</td>\n",
       "      <td>19</td>\n",
       "      <td>40</td>\n",
       "      <td>1729602299</td>\n",
       "      <td>checkpoint_000019</td>\n",
       "      <td>...</td>\n",
       "      <td>24953</td>\n",
       "      <td>Knathan-Laptop</td>\n",
       "      <td>172.31.231.162</td>\n",
       "      <td>56.653336</td>\n",
       "      <td>20</td>\n",
       "      <td>2</td>\n",
       "      <td>2200</td>\n",
       "      <td>2</td>\n",
       "      <td>400</td>\n",
       "      <td>d775c15d</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2 rows × 27 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   train_loss  train_loss_step  val/rmse   val/mae  val_loss  \\\n",
       "0    0.099042         0.168217  0.861368  0.700675  0.741955   \n",
       "1    0.069695         0.119898  0.902580  0.717637  0.814650   \n",
       "\n",
       "   train_loss_epoch  epoch  step   timestamp checkpoint_dir_name  ...    pid  \\\n",
       "0          0.099042     19    40  1729602279   checkpoint_000019  ...  24873   \n",
       "1          0.069695     19    40  1729602299   checkpoint_000019  ...  24953   \n",
       "\n",
       "         hostname         node_ip time_since_restore iterations_since_restore  \\\n",
       "0  Knathan-Laptop  172.31.231.162          49.881516                       20   \n",
       "1  Knathan-Laptop  172.31.231.162          56.653336                       20   \n",
       "\n",
       "   config/train_loop_config/depth  config/train_loop_config/ffn_hidden_dim  \\\n",
       "0                               2                                     2000   \n",
       "1                               2                                     2200   \n",
       "\n",
       "   config/train_loop_config/ffn_num_layers  \\\n",
       "0                                        2   \n",
       "1                                        2   \n",
       "\n",
       "  config/train_loop_config/message_hidden_dim    logdir  \n",
       "0                                         500  f1a6e41a  \n",
       "1                                         400  d775c15d  \n",
       "\n",
       "[2 rows x 27 columns]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# results of all trials: one row per trial with its last reported metrics,\n",
    "# plus the sampled hyperparameters as config/train_loop_config/* columns\n",
    "result_df = results.get_dataframe()\n",
    "result_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'depth': 2,\n",
       " 'ffn_hidden_dim': 2000,\n",
       " 'ffn_num_layers': 2,\n",
       " 'message_hidden_dim': 500}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# best configuration: the trial ranked best by the TuneConfig metric/mode (val_loss, min).\n",
    "# The searched hyperparameters sit under the 'train_loop_config' key used in param_space.\n",
    "best_result = results.get_best_result()\n",
    "best_config = best_result.config\n",
    "best_config['train_loop_config']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best model checkpoint path: /home/knathan/chemprop/examples/hpopt/ray_results/TorchTrainer_2024-10-22_09-03-37/f1a6e41a/checkpoint_000019/checkpoint.ckpt\n"
     ]
    }
   ],
   "source": [
    "# best model checkpoint path\n",
    "# best_result is re-fetched here so this cell does not depend on the previous one.\n",
    "# NOTE: 'checkpoint.ckpt' is the file inside the Ray checkpoint directory; presumably\n",
    "# written under that name by train_model (defined earlier) - verify if that changes.\n",
    "best_result = results.get_best_result()\n",
    "best_checkpoint_path = Path(best_result.checkpoint.path) / \"checkpoint.ckpt\"\n",
    "print(f\"Best model checkpoint path: {best_checkpoint_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shut down the Ray runtime started by ray.init() in the tuning cell\n",
    "ray.shutdown()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
